// Copyright Epic Games, Inc. All Rights Reserved.

#include "rpcreplay_cmd.h"

#include <zencore/compactbinarybuilder.h>
#include <zencore/filesystem.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/process.h>
#include <zencore/scopeguard.h>
#include <zencore/session.h>
#include <zencore/stream.h>
#include <zencore/timer.h>
#include <zencore/workthreadpool.h>
#include <zenhttp/httpcommon.h>
#include <zenutil/cache/rpcrecording.h>
#include <zenutil/packageformat.h>

ZEN_THIRD_PARTY_INCLUDES_START
#include <cpr/cpr.h>
#include <fmt/format.h>
#include <gsl/gsl-lite.hpp>
ZEN_THIRD_PARTY_INCLUDES_END

#include <memory>

namespace zen {

using namespace std::literals;

// Registers the options for starting a server-side RPC recording: an optional host URL (resolved in Run)
// and the recording path, which may also be supplied as a positional argument.
RpcStartRecordingCommand::RpcStartRecordingCommand()
{
	auto& Opts = m_Options;

	Opts.add_options()("h,help", "Print help");
	Opts.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>");
	Opts.add_option("", "p", "path", "Recording file path", cxxopts::value(m_RecordingPath), "<path>");

	// Accept the path without the --path switch as well.
	Opts.parse_positional("path");
}

RpcStartRecordingCommand::~RpcStartRecordingCommand() = default;

// Asks the target server to start recording RPC requests by POSTing to
// "<host>/z$/exec$/start-recording" with the recording path as the "path" query parameter.
// Returns 0 without doing anything when ParseOptions() reports it should not proceed; otherwise
// prints the HTTP response and returns it mapped to a command exit code.
// Throws OptionParseException when the host cannot be resolved or no path was given.
int
RpcStartRecordingCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
	ZEN_UNUSED(GlobalOptions, argc, argv);

	const bool ShouldRun = ParseOptions(argc, argv);
	if (!ShouldRun)
	{
		return 0;
	}

	m_HostName = ResolveTargetHostSpec(m_HostName);
	if (m_HostName.empty())
	{
		throw zen::OptionParseException("unable to resolve server specification");
	}
	if (m_RecordingPath.empty())
	{
		throw zen::OptionParseException("Rpc start recording command requires a path");
	}

	const std::string StartRecordingUrl = fmt::format("{}/z$/exec$/start-recording"sv, m_HostName);

	cpr::Session Session;
	Session.SetUrl(StartRecordingUrl);
	Session.SetParameters({{"path", m_RecordingPath}});

	cpr::Response Response = Session.Post();
	ZEN_CONSOLE("{}", FormatHttpResponse(Response));
	return MapHttpToCommandReturnCode(Response);
}

////////////////////////////////////////////////////

// Registers the options for stopping a server-side RPC recording; only the target host is configurable.
RpcStopRecordingCommand::RpcStopRecordingCommand()
{
	auto& Opts = m_Options;

	Opts.add_options()("h,help", "Print help");
	Opts.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>");
}

RpcStopRecordingCommand::~RpcStopRecordingCommand() = default;

// Asks the target server to stop recording RPC requests by POSTing to "<host>/z$/exec$/stop-recording".
// Returns 0 without doing anything when ParseOptions() reports it should not proceed; otherwise
// prints the HTTP response and returns it mapped to a command exit code.
// Throws OptionParseException when the host cannot be resolved.
int
RpcStopRecordingCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
	ZEN_UNUSED(GlobalOptions, argc, argv);

	const bool ShouldRun = ParseOptions(argc, argv);
	if (!ShouldRun)
	{
		return 0;
	}

	m_HostName = ResolveTargetHostSpec(m_HostName);
	if (m_HostName.empty())
	{
		throw zen::OptionParseException("unable to resolve server specification");
	}

	const std::string StopRecordingUrl = fmt::format("{}/z$/exec$/stop-recording"sv, m_HostName);

	cpr::Session Session;
	Session.SetUrl(StopRecordingUrl);

	cpr::Response Response = Session.Post();
	ZEN_CONSOLE("{}", FormatHttpResponse(Response));
	return MapHttpToCommandReturnCode(Response);
}

////////////////////////////////////////////////////

// Registers the RPC replay options. The recording path may be given positionally, and the worker thread
// count defaults to std::thread::hardware_concurrency(). The local-ref flags control how the
// "AcceptFlags" / "Pid" fields of each replayed request are rewritten in Run().
RpcReplayCommand::RpcReplayCommand()
{
	auto& Opts = m_Options;

	const std::string DefaultThreadCount = fmt::format("{}", std::thread::hardware_concurrency());

	// Target and input
	Opts.add_options()("h,help", "Print help");
	Opts.add_option("", "u", "hosturl", "Host URL", cxxopts::value(m_HostName)->default_value(""), "<hosturl>");
	Opts.add_option("", "p", "path", "Recording file path", cxxopts::value(m_RecordingPath), "<path>");

	// Execution mode
	Opts.add_option("", "", "dry", "Do a dry run", cxxopts::value(m_DryRun), "<enable>");
	Opts.add_option("",
					"w",
					"numthreads",
					"Number of worker threads per process",
					cxxopts::value(m_ThreadCount)->default_value(DefaultThreadCount),
					"<count>");
	Opts.add_option("", "", "onhost", "Replay on host, bypassing http/network layer", cxxopts::value(m_OnHost), "<onhost>");
	Opts.add_option("",
					"",
					"showmethodstats",
					"Show statistics of which RPC methods are used",
					cxxopts::value(m_ShowMethodStats),
					"<showmethodstats>");

	// Which recorded requests to replay, and how the work is split up
	Opts.add_option("",
					"",
					"offset",
					"Offset into request recording to start replay",
					cxxopts::value(m_Offset)->default_value("0"),
					"<offset>");
	Opts.add_option("",
					"",
					"stride",
					"Stride for request recording when replaying requests",
					cxxopts::value(m_Stride)->default_value("1"),
					"<stride>");
	Opts.add_option("", "", "numproc", "Number of worker processes", cxxopts::value(m_ProcessCount)->default_value("1"), "<count>");

	// Local reference overrides (enable/disable pairs)
	Opts.add_option("",
					"",
					"forceallowlocalrefs",
					"Force enable local refs in requests",
					cxxopts::value(m_ForceAllowLocalRefs),
					"<enable>");
	Opts.add_option("", "", "disablelocalrefs", "Force disable local refs in requests", cxxopts::value(m_DisableLocalRefs), "<enable>");
	Opts.add_option("",
					"",
					"forceallowlocalhandlerefs",
					"Force enable local refs as handles in requests",
					cxxopts::value(m_ForceAllowLocalHandleRef),
					"<enable>");
	Opts.add_option("",
					"",
					"disablelocalhandlerefs",
					"Force disable local refs as handles in requests",
					cxxopts::value(m_DisableLocalHandleRefs),
					"<enable>");
	Opts.add_option("",
					"",
					"forceallowpartiallocalrefs",
					"Force enable local refs for all sizes",
					cxxopts::value(m_ForceAllowPartialLocalRefs),
					"<enable>");
	Opts.add_option("",
					"",
					"disablepartiallocalrefs",
					"Force disable local refs for all sizes",
					cxxopts::value(m_DisablePartialLocalRefs),
					"<enable>");

	// Accept the path without the --path switch as well.
	Opts.parse_positional("path");
}

RpcReplayCommand::~RpcReplayCommand() = default;

int
RpcReplayCommand::Run(const ZenCliOptions& GlobalOptions, int argc, char** argv)
{
	ZEN_UNUSED(GlobalOptions, argc, argv);

	if (!ParseOptions(argc, argv))
	{
		return 0;
	}

	m_HostName = ResolveTargetHostSpec(m_HostName);

	if (m_HostName.empty())
	{
		throw zen::OptionParseException("unable to resolve server specification");
	}

	if (m_RecordingPath.empty())
	{
		throw zen::OptionParseException("Rpc replay command requires a path");
	}

	if (!std::filesystem::exists(m_RecordingPath) || !std::filesystem::is_directory(m_RecordingPath))
	{
		throw std::runtime_error(fmt::format("could not find recording at '{}'", m_RecordingPath));
	}

	m_ThreadCount = Max(m_ThreadCount, 1);

	Stopwatch TotalTimer;

	if (m_OnHost)
	{
		cpr::Session Session;
		Session.SetUrl(fmt::format("{}/z$/exec$/replay-recording"sv, m_HostName));
		Session.SetParameters({{"path", m_RecordingPath}, {"thread-count", fmt::format("{}", m_ThreadCount)}});
		cpr::Response Response = Session.Post();
		ZEN_CONSOLE("{}", FormatHttpResponse(Response));
		return MapHttpToCommandReturnCode(Response);
	}

	std::unique_ptr<cache::IRpcRequestReplayer> Replayer   = cache::MakeDiskRequestReplayer(m_RecordingPath, true);
	uint64_t									EntryCount = Replayer->GetRequestCount();

	std::atomic_uint64_t EntryOffset   = m_Offset;
	std::atomic_uint64_t BytesSent	   = 0;
	std::atomic_uint64_t BytesReceived = 0;

	Stopwatch Timer;

	if (m_ProcessCount > 1)
	{
		std::vector<std::unique_ptr<ProcessHandle>> WorkerProcesses;
		WorkerProcesses.resize(m_ProcessCount);

		ProcessMonitor Monitor;
		for (int ProcessIndex = 0; ProcessIndex < m_ProcessCount; ++ProcessIndex)
		{
			std::string CommandLine =
				fmt::format("{} rpc-record-replay --hosturl {} --path \"{}\" --offset {} --stride {} --numthreads {} --numproc {}"sv,
							argv[0],
							m_HostName,
							m_RecordingPath,
							m_Stride == 1 ? 0 : m_Offset + ProcessIndex,
							m_Stride,
							m_ThreadCount,
							1);
			CreateProcResult Result(CreateProc(std::filesystem::path(std::string(argv[0])), CommandLine));
			WorkerProcesses[ProcessIndex] = std::make_unique<ProcessHandle>();
			WorkerProcesses[ProcessIndex]->Initialize(Result);
			Monitor.AddPid(WorkerProcesses[ProcessIndex]->Pid());
		}
		while (Monitor.IsRunning())
		{
			ZEN_CONSOLE("Waiting for worker processes...");
			Sleep(1000);
		}
		return 0;
	}
	else
	{
		std::map<std::string, size_t> MethodTypes;
		RwLock						  MethodTypesLock;

		WorkerThreadPool WorkerPool(m_ThreadCount);

		Latch WorkLatch(m_ThreadCount);
		for (int WorkerIndex = 0; WorkerIndex < m_ThreadCount; ++WorkerIndex)
		{
			WorkerPool.ScheduleWork(
				[this, &WorkLatch, EntryCount, &EntryOffset, &Replayer, &BytesSent, &BytesReceived, &MethodTypes, &MethodTypesLock]() {
					auto _ = MakeGuard([&WorkLatch]() { WorkLatch.CountDown(); });

					std::map<std::string, size_t> LocalMethodTypes;

					auto ReduceTypes = MakeGuard([&] {
						RwLock::ExclusiveLockScope __(MethodTypesLock);

						for (auto& Entry : LocalMethodTypes)
						{
							MethodTypes[Entry.first] += Entry.second;
						}
					});

					cpr::Session Session;
					Session.SetUrl(fmt::format("{}/z$/$rpc"sv, m_HostName));

					uint64_t EntryIndex = EntryOffset.fetch_add(m_Stride);
					while (EntryIndex < EntryCount)
					{
						IoBuffer							  Payload;
						const zen::cache::RecordedRequestInfo RequestInfo = Replayer->GetRequest(EntryIndex, /* out */ Payload);

						if (RequestInfo != zen::cache::RecordedRequestInfo::NullRequest)
						{
							CbPackage RequestPackage;
							CbObject  Request;

							switch (RequestInfo.ContentType)
							{
								case ZenContentType::kCbPackage:
									{
										if (ParsePackageMessageWithLegacyFallback(Payload, RequestPackage))
										{
											Request = RequestPackage.GetObject();
										}
									}
									break;
								case ZenContentType::kCbObject:
									{
										Request = LoadCompactBinaryObject(Payload);
									}
									break;
							}

							RpcAcceptOptions OriginalAcceptOptions = static_cast<RpcAcceptOptions>(Request["AcceptFlags"sv].AsUInt16(0u));
							int				 OriginalProcessPid	   = Request["Pid"sv].AsInt32(0);

							int				 AdjustedPid		   = 0;
							RpcAcceptOptions AdjustedAcceptOptions = RpcAcceptOptions::kNone;

							if (!m_DisableLocalRefs)
							{
								if (EnumHasAnyFlags(OriginalAcceptOptions, RpcAcceptOptions::kAllowLocalReferences) ||
									m_ForceAllowLocalRefs)
								{
									AdjustedAcceptOptions |= RpcAcceptOptions::kAllowLocalReferences;
									if (!m_DisablePartialLocalRefs)
									{
										if (EnumHasAnyFlags(OriginalAcceptOptions, RpcAcceptOptions::kAllowPartialLocalReferences) ||
											m_ForceAllowPartialLocalRefs)
										{
											AdjustedAcceptOptions |= RpcAcceptOptions::kAllowPartialLocalReferences;
										}
									}
									if (!m_DisableLocalHandleRefs)
									{
										if (OriginalProcessPid != 0 || m_ForceAllowLocalHandleRef)
										{
											AdjustedPid = GetCurrentProcessId();
										}
									}
								}
							}

							if (m_ShowMethodStats)
							{
								std::string MethodName = std::string(Request["Method"sv].AsString());
								if (auto It = LocalMethodTypes.find(MethodName); It != LocalMethodTypes.end())
								{
									It->second++;
								}
								else
								{
									LocalMethodTypes[MethodName] = 1;
								}
							}

							if (OriginalAcceptOptions != AdjustedAcceptOptions || OriginalProcessPid != AdjustedPid)
							{
								CbObjectWriter RequestCopyWriter;
								for (const CbFieldView& Field : Request)
								{
									if (!Field.HasName())
									{
										RequestCopyWriter.AddField(Field);
										continue;
									}
									std::string_view FieldName = Field.GetName();
									if (FieldName == "Pid"sv)
									{
										continue;
									}
									if (FieldName == "AcceptFlags"sv)
									{
										continue;
									}
									RequestCopyWriter.AddField(FieldName, Field);
								}
								if (AdjustedPid != 0)
								{
									RequestCopyWriter.AddInteger("Pid"sv, AdjustedPid);
								}
								if (AdjustedAcceptOptions != RpcAcceptOptions::kNone)
								{
									RequestCopyWriter.AddInteger("AcceptFlags"sv, static_cast<uint16_t>(AdjustedAcceptOptions));
								}

								if (RequestInfo.ContentType == ZenContentType::kCbPackage)
								{
									RequestPackage.SetObject(RequestCopyWriter.Save());
									std::vector<IoBuffer>	  Buffers = FormatPackageMessage(RequestPackage);
									std::vector<SharedBuffer> SharedBuffers(Buffers.begin(), Buffers.end());
									Payload = CompositeBuffer(std::move(SharedBuffers)).Flatten().AsIoBuffer();
								}
								else
								{
									RequestCopyWriter.Finalize();
									Payload = IoBuffer(RequestCopyWriter.GetSaveSize());
									RequestCopyWriter.Save(Payload.GetMutableView());
								}
							}

							if (!m_DryRun)
							{
								StringBuilder<32> SessionIdString;

								if (RequestInfo.SessionId != Oid::Zero)
								{
									RequestInfo.SessionId.ToString(SessionIdString);
								}
								else
								{
									GetSessionId().ToString(SessionIdString);
								}

								Session.SetHeader({{"Content-Type", std::string(MapContentTypeToString(RequestInfo.ContentType))},
												   {"Accept", std::string(MapContentTypeToString(RequestInfo.AcceptType))},
												   {"UE-Session", std::string(SessionIdString)}});

								uint64_t Offset		  = 0;
								auto	 ReadCallback = [&Payload, &Offset](char* buffer, size_t& size, intptr_t) {
									size						   = Min<size_t>(size, Payload.GetSize() - Offset);
									IoBuffer		  PayloadRange = IoBuffer(Payload, Offset, size);
									MutableMemoryView Data(buffer, size);
									Data.CopyFrom(PayloadRange.GetView());
									Offset += size;
									return true;
								};
								Session.SetReadCallback(cpr::ReadCallback(gsl::narrow<cpr::cpr_off_t>(Payload.GetSize()), ReadCallback));
								cpr::Response Response = Session.Post();
								BytesSent.fetch_add(Payload.GetSize());
								if (Response.error || !(IsHttpSuccessCode(Response.status_code) ||
														Response.status_code == gsl::narrow<long>(HttpResponseCode::NotFound)))
								{
									ZEN_CONSOLE("{}", FormatHttpResponse(Response));
									break;
								}
								BytesReceived.fetch_add(Response.downloaded_bytes);
							}
						}

						EntryIndex = EntryOffset.fetch_add(m_Stride);
					}
				});
		}

		while (!WorkLatch.Wait(1000))
		{
			const uint64_t RequestsTotal	 = (EntryCount - m_Offset) / m_Stride;
			const uint64_t RequestsRemaining = (EntryCount - EntryOffset.load()) / m_Stride;

			ZEN_CONSOLE("[{:3}%] [{}] {} requests, {} remaining (sent {}, received {})",
						(RequestsTotal - RequestsRemaining) * 100 / RequestsTotal,
						NiceTimeSpanMs(Timer.GetElapsedTimeMs()),
						RequestsTotal,
						RequestsRemaining,
						NiceBytes(BytesSent.load()),
						NiceBytes(BytesReceived.load()));
		}

		if (m_ShowMethodStats)
		{
			for (const auto& It : MethodTypes)
			{
				ZEN_CONSOLE("{:18}: {:10}", It.first, It.second);
			}
		}
	}

	const uint64_t RequestsSent = (EntryOffset.load() - m_Offset) / m_Stride;
	const uint64_t ElapsedMS	= Timer.GetElapsedTimeMs();
	const uint64_t Sent			= BytesSent.load();
	const uint64_t Received		= BytesReceived.load();

	ZEN_CONSOLE("Processed requests: {} ({}), payloads sent {} ({}), payloads received {} ({}) in {}.\nTotal runtime: {}",
				RequestsSent,
				NiceRate(RequestsSent, ElapsedMS, "req"),
				NiceBytes(Sent),
				NiceByteRate(Sent, ElapsedMS),
				NiceBytes(Received),
				NiceByteRate(Received, ElapsedMS),
				NiceTimeSpanMs(ElapsedMS),
				NiceTimeSpanMs(TotalTimer.GetElapsedTimeMs()));

	return 0;
}

}  // namespace zen
